﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace WorldLab.MathLib
{
    /// <summary>
    /// Weighted logistic regression fitted with iterative Newton-Raphson updates:
    /// beta(t) = beta(t-1) + inv(X' W V X) X' W (y - p), where V = diag(p * (1 - p))
    /// and W = diag(w) holds the per-observation weights.
    /// </summary>
    public class LogisticRegression
    {
        /// <summary>Iteration cap used by the three-argument overload.</summary>
        private const int DefaultMaxIterations = 30;

        /// <summary>Convergence tolerance used by the three-argument overload.</summary>
        private const double DefaultEpsilon = 0.001;

        /// <summary>Relative-change divergence limit used by the three-argument overload.</summary>
        private const double DefaultBetaDifference = 1000.0;

        /// <summary>
        /// Computes the coefficient vector using the default settings
        /// (30 iterations, tolerance 0.001, divergence ratio 1000).
        /// </summary>
        /// <param name="xMatrix">Design matrix, one row per observation and one column per coefficient.</param>
        /// <param name="yVector">Observed responses, one entry per row of <paramref name="xMatrix"/>.</param>
        /// <param name="wVector">Observation weights, one entry per row of <paramref name="xMatrix"/>.</param>
        /// <returns>The last accepted coefficient vector; its length is the column count of <paramref name="xMatrix"/>.</returns>
        public static double[] ComputeBestBeta(Matrix xMatrix, Vector yVector, Vector wVector)
        {
            return ComputeBestBeta(xMatrix, yVector, wVector,
                DefaultMaxIterations, DefaultEpsilon, DefaultBetaDifference);
        }

        /// <summary>
        /// Computes the coefficient vector with caller-supplied stopping criteria.
        /// </summary>
        /// <param name="xMatrix">Design matrix, one row per observation and one column per coefficient.</param>
        /// <param name="yVector">Observed responses, one entry per row of <paramref name="xMatrix"/>.</param>
        /// <param name="wVector">Observation weights, one entry per row of <paramref name="xMatrix"/>.</param>
        /// <param name="maxIterations">Maximum number of update steps; a non-positive value yields the all-zero start vector.</param>
        /// <param name="epsilon">Iteration stops once every coefficient changes by at most this absolute amount.</param>
        /// <param name="betaDifference">Iteration is abandoned when a coefficient's relative change exceeds this ratio.</param>
        /// <returns>
        /// The converged coefficients, or the last accepted estimate when the update
        /// fails (non-invertible matrix, non-finite values), diverges, or the
        /// iteration cap is reached.
        /// </returns>
        /// <exception cref="ArgumentException">The dimensions of the inputs do not agree.</exception>
        public static double[] ComputeBestBeta(Matrix xMatrix, Vector yVector, Vector wVector,
            int maxIterations, double epsilon, double betaDifference)
        {
            int xCols = xMatrix.GetNumColumns();

            // A freshly allocated double[] is already zero-filled, so the start
            // estimate is beta = 0 without an explicit initialisation loop.
            double[] bVector = new double[xCols];
            double[] bestBvector = Vector.Duplicate(bVector);
            Vector pVector = ComputeProbVector(xMatrix, bVector);

            for (int i = 0; i < maxIterations; ++i)
            {
                double[] newBvector = ComputeNewBetaVector(bVector, xMatrix, yVector, wVector,
                    pVector);

                // A null result means the update could not be computed. NaN or
                // infinite coefficients would poison every later iteration, so
                // both cases fall back to the last good estimate.
                if (newBvector == null || !AllFinite(newBvector))
                {
                    return bestBvector;
                }
                if (NoChange(bVector, newBvector, epsilon))
                {
                    return newBvector;
                }
                if (OutOfControl(bVector, newBvector, betaDifference))
                {
                    return bestBvector;
                }

                pVector = ComputeProbVector(xMatrix, newBvector);
                bVector = newBvector;
                bestBvector = Vector.Duplicate(newBvector);
            }
            return bestBvector;
        }

        /// <summary>
        /// Performs one update step and returns the new coefficient vector.
        /// </summary>
        /// <param name="oldBetaVector">Coefficients from the previous iteration.</param>
        /// <param name="xMatrix">Design matrix.</param>
        /// <param name="yVector">Observed responses.</param>
        /// <param name="wVector">Observation weights.</param>
        /// <param name="oldProbVector">Probabilities computed from <paramref name="oldBetaVector"/>.</param>
        /// <returns>The updated coefficients, or <c>null</c> when <c>Invert()</c> returns <c>null</c>.</returns>
        static double[] ComputeNewBetaVector(double[] oldBetaVector, Matrix xMatrix,
            Vector yVector, Vector wVector, Vector oldProbVector)
        {
            // X'
            Matrix xTransposed = xMatrix.Transpose();
            // W * V(t-1) * X
            Matrix weightedX = ComputeXwp(oldProbVector, xMatrix, wVector);
            // X' * W * V(t-1) * X
            Matrix information = xTransposed.Multiply(weightedX);
            // inv(X' * W * V(t-1) * X); a null result is treated as failure
            // (presumably a singular matrix - Invert() is defined elsewhere).
            Matrix informationInverse = information.Invert();
            if (informationInverse == null)
            {
                return null;
            }
            // W * (y - p(t-1))
            Vector weightedResidual = ComputeYwp(oldProbVector, yVector, wVector);
            // X' * W * (y - p(t-1)). Multiplying by the vector first avoids
            // building the intermediate inv(...) * X' matrix, which has one
            // column per observation.
            Vector score = xTransposed.Multiply(weightedResidual);
            // inv(X' * W * V(t-1) * X) * X' * W * (y - p(t-1))
            Vector step = informationInverse.Multiply(score);

            double[] result = new double[oldBetaVector.Length];
            for (int i = 0; i < oldBetaVector.Length; ++i)
            {
                result[i] = oldBetaVector[i] + step.GetElement(i);
            }
            return result;
        }

        /// <summary>
        /// Computes W * V * X, i.e. scales row i of X by w[i] * p[i] * (1 - p[i]).
        /// </summary>
        /// <param name="pVector">Probabilities, one per row of <paramref name="xMatrix"/>.</param>
        /// <param name="xMatrix">Design matrix.</param>
        /// <param name="wVector">Observation weights, read at the same indices as <paramref name="pVector"/>.</param>
        /// <returns>A new matrix with the same dimensions as <paramref name="xMatrix"/>.</returns>
        /// <exception cref="ArgumentException">The length of <paramref name="pVector"/> differs from the row count of <paramref name="xMatrix"/>.</exception>
        static Matrix ComputeXwp(Vector pVector, Matrix xMatrix, Vector wVector)
        {
            int pRows = pVector.GetNums();
            int xRows = xMatrix.GetNumRows();
            int xCols = xMatrix.GetNumColumns();
            if (pRows != xRows)
            {
                throw new ArgumentException("Vector P does not match matrix X.");
            }

            Matrix result = new Matrix(pRows, xCols);
            for (int i = 0; i < pRows; ++i)
            {
                // The row factor does not depend on the column, so compute it once.
                double p = pVector.GetElement(i);
                double rowScale = wVector.GetElement(i) * p * (1.0 - p);
                for (int j = 0; j < xCols; ++j)
                {
                    result.SetElement(i, j, rowScale * xMatrix.GetElement(i, j));
                }
            }
            return result;
        }

        /// <summary>
        /// Computes W * (y - p), the weighted residual vector.
        /// </summary>
        /// <param name="pVector">Probabilities.</param>
        /// <param name="yVector">Observed responses.</param>
        /// <param name="wVector">Observation weights, read at the same indices as <paramref name="pVector"/>.</param>
        /// <returns>The vector produced by <c>yVector - pVector</c> with each element scaled by its weight.</returns>
        /// <exception cref="ArgumentException">The lengths of <paramref name="pVector"/> and <paramref name="yVector"/> differ.</exception>
        static Vector ComputeYwp(Vector pVector, Vector yVector, Vector wVector)
        {
            int pNums = pVector.GetNums();
            int yNums = yVector.GetNums();
            if (pNums != yNums)
            {
                throw new ArgumentException("Vector P does not match vector Y.");
            }

            // y - p(t-1)
            Vector result = yVector - pVector;
            // W * (y - p(t-1))
            for (int i = 0; i < pNums; ++i)
            {
                result.SetElement(i, wVector.GetElement(i) * result.GetElement(i));
            }
            return result;
        }

        /// <summary>
        /// Computes p[i] = 1 / (1 + exp(-x[i] . beta)) for every row i of X.
        /// </summary>
        /// <param name="xMatrix">Design matrix.</param>
        /// <param name="bVector">Coefficients; one entry is read per column of <paramref name="xMatrix"/>.</param>
        /// <returns>A vector with one probability per row of <paramref name="xMatrix"/>.</returns>
        static Vector ComputeProbVector(Matrix xMatrix, double[] bVector)
        {
            int xRows = xMatrix.GetNumRows();
            int xCols = xMatrix.GetNumColumns();
            Vector result = new Vector(xRows);
            for (int i = 0; i < xRows; ++i)
            {
                // Linear predictor for row i.
                double z = 0.0;
                for (int j = 0; j < xCols; ++j)
                {
                    z += xMatrix.GetElement(i, j) * bVector[j];
                }
                result.SetElement(i, 1.0 / (1.0 + Math.Exp(-z)));
            }
            return result;
        }

        /// <summary>
        /// Reports whether every coefficient changed by at most <paramref name="eps"/>.
        /// </summary>
        /// <param name="oldBeta">Previous coefficients.</param>
        /// <param name="newBeta">Updated coefficients.</param>
        /// <param name="eps">Absolute tolerance.</param>
        /// <returns><c>true</c> when no coefficient moved by more than <paramref name="eps"/>.</returns>
        static bool NoChange(double[] oldBeta, double[] newBeta, double eps)
        {
            for (int i = 0; i < oldBeta.Length; ++i)
            {
                if (Math.Abs(oldBeta[i] - newBeta[i]) > eps)
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Reports whether any coefficient changed by more than
        /// <paramref name="betaDifference"/> times its previous magnitude.
        /// </summary>
        /// <param name="oldBeta">Previous coefficients.</param>
        /// <param name="newBeta">Updated coefficients.</param>
        /// <param name="betaDifference">Maximum tolerated relative change.</param>
        /// <returns><c>true</c> when at least one coefficient exceeds the limit.</returns>
        static bool OutOfControl(double[] oldBeta, double[] newBeta, double betaDifference)
        {
            for (int i = 0; i < oldBeta.Length; ++i)
            {
                // The relative change is undefined for a zero baseline, so skip
                // only that coefficient and keep checking the remaining ones.
                if (oldBeta[i] == 0.0)
                {
                    continue;
                }

                if (Math.Abs(oldBeta[i] - newBeta[i]) / Math.Abs(oldBeta[i])
                    > betaDifference)
                {
                    return true;
                }
            }
            return false;
        }

        /// <summary>
        /// Reports whether every element is neither NaN nor infinite.
        /// </summary>
        /// <param name="values">Values to inspect.</param>
        /// <returns><c>true</c> when all elements are finite.</returns>
        static bool AllFinite(double[] values)
        {
            for (int i = 0; i < values.Length; ++i)
            {
                if (double.IsNaN(values[i]) || double.IsInfinity(values[i]))
                {
                    return false;
                }
            }
            return true;
        }
    }
}


    